% Copyright (C) 2019-2023 Benjamin Born, Francesco D'Ascanio, Gernot J. Mueller, Johannes Pfeifer
%
% This is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
%
% It is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
% 
% For a copy of the GNU General Public License,
% see <http://www.gnu.org/licenses/>.

% Uses policy-function iteration to obtain the policy functions under a currency peg.
% (c) Johannes Pfeifer, based on code by Stephanie Schmitt-Grohe and Martin Uribe (2010).

clear all
addpath('Auxiliary_functions/')
start_from_previous_run=0; % nonzero: build the initial narrow window from dp_index stored in the previous run's .mat file

if ~isfolder('Policy_functions')
    mkdir('.','Policy_functions');
end

[MC, par, grid]=get_parameterization;

narrow_window_dp_size = 21;  %must be odd; use the moving narrow window technique developed in this program to handle large-dimension state-space problems
% The window is centred on a reference debt index, so it needs an odd
% number of points, and it must fit inside the debt grid.
if mod(narrow_window_dp_size,2)~=1
    error('narrow_window_dp_size must be odd')
end
if narrow_window_dp_size>grid.n_d
    error('narrow_window_dp_size must not exceed the number of debt grid points grid.n_d')
end

r_grid = repmat(grid.r_level,[grid.n_d*grid.n_w narrow_window_dp_size]); % interest rate on the n x narrow-window array
g_values=unique(grid.g_level);

n = grid.n_y*grid.n_d*grid.n_w; %number of states

filename = 'Policy_functions/pfi_peg_g_low_beta_large_grid';

lambda = ones(grid.n_y,grid.n_d,grid.n_w)*0.5; % initial guess for lambda_t (marginal utility of tradables)

% Initial guess for the narrow window of candidate d_t+1 grid indices: one
% row of narrow_window_dp_size consecutive debt-grid indices per state
% (n rows; states ordered y fastest, then d, then w). Each window is
% centred on a reference debt index and clamped to stay inside 1..grid.n_d.
shift_n_dp_try = (narrow_window_dp_size-1)/2;
window_offsets = 0:narrow_window_dp_size-1;
highest_window_start = grid.n_d-narrow_window_dp_size+1;
if ~start_from_previous_run
    % Reference index: today's debt index (diagonal). Near the grid ends the
    % window sticks to the first/last narrow_window_dp_size points.
    window_start = min(max((1:grid.n_d)'-shift_n_dp_try,1),highest_window_start);
    dp_try_index = window_start + window_offsets; % n_d x narrow_window_dp_size
    dp_try_index = kron(dp_try_index,ones(grid.n_y,1)); % repeat each d row for every y
    dp_try_index = repmat(dp_try_index,[grid.n_w 1]);   % and for every w
else
    % Reference index: the previous run's optimal d_t+1 index, with the
    % current exogenous states (rows of MC.S) mapped to the nearest old ones.
    % NOTE(review): assumes the old run used the same n_d and n_w - confirm.
    old_run=load('Policy_functions/pfi_peg_g_low_beta_large_grid.mat','MC','dp_index');
    k = dsearchn(old_run.MC.S,MC.S);
    dp_index_old=old_run.dp_index(k,:,:);
    dp_index_old=dp_index_old(:);
    % Centre the window on the old optimum (previously it was placed so the
    % old optimum sat on the window's upper edge).
    window_start = min(max(dp_index_old-shift_n_dp_try,1),highest_window_start);
    dp_try_index = window_start + window_offsets; % n x narrow_window_dp_size
    clear('k','dp_index_old','old_run')
end
clear('window_start','window_offsets','highest_window_start')

% Initialize flags for whether the bounds of the narrow window are hit
not_hit_high = 0; % not_hit_high=1 if no state picks the upper end of its window, or it does but that end is the top of the debt grid (grid.d_grid(end)); zero otherwise
not_hit_low = 0; % not_hit_low=1 if no state picks the lower end of its window, or it does but that end is the bottom of the debt grid (grid.d_grid(1)); zero otherwise

iter = 0; % inner (Euler-equation) iteration counter

max_iter = 10000; % cap on inner iterations

yT_grid = repmat(grid.y_level(:),[grid.n_d*grid.n_w narrow_window_dp_size]); %build yT grid on the n x narrow-window array

% fsolve settings for the hours first-order condition; the objective
% supplies its own Jacobian (SpecifyObjectiveGradient).
options = optimoptions(@fsolve,'Algorithm','trust-region-reflective','SpecifyObjectiveGradient',true,...
    'MaxIterations',400,...
    'StepTolerance',1e-12,...
    'CheckGradients',false,...
    'Display','off');

overall_iter_number=1;

% The lagged-wage state and the g rule built from it do not change between
% window adjustments, so they are constructed once before the outer loop.
w_grid = repmat(grid.w_grid_vec',[grid.n_y*grid.n_d 1]); % w state for every (y,d) row
w_grid = repmat(w_grid(:),[1 narrow_window_dp_size]);    % n x narrow_window_dp_size
% NOTE(review): 2.6 is a hard-coded wage normalization in the g rule - confirm its source
g_rule_grid=repmat(grid.g_level,[grid.n_d*grid.n_w narrow_window_dp_size])+par.phi_w*(w_grid/2.6-1);

% Outer loop: solve the Euler equation on the current narrow windows, then
% move every window whose interior edge was chosen. Stop once neither edge
% is hit. (The inner loop raises an error when it reaches max_iter, so
% iteration exhaustion does not need to be tested here.)
while ~(not_hit_high && not_hit_low)
    %for all choices for d_t+1 contained in narrow window, compute wage,
    %i.e. deal with wage constraint; required for lambda_t+1 as w_t will be
    %state variable then
    dp_try_index_last_column_equal_nd = (dp_try_index(:,end)==grid.n_d); %does last point in narrow window imply binding BC

    d_try = repmat(grid.d_grid',[grid.n_y 1 grid.n_w narrow_window_dp_size]); %4D-array of debt today on n_y*n_d*n_w*dp array
    d_try = reshape(d_try,n,narrow_window_dp_size);   % reshape to n*narrow window grid

    cT_try = yT_grid + grid.d_grid(dp_try_index)./(1+r_grid) - d_try;% consumption of tradables
    clear('d_try'); %keep memory use down
    if min(max(cT_try,[],2))<0 % some state has negative c_T for every candidate in its window
        warning('Natural debt limit violated')
    end

    wfull_try =  (1-par.omega)/par.omega*((par.hbar^par.alfa-g_rule_grid)./abs(cT_try)).^(-1/par.xi)*par.alfa*par.hbar^(par.alfa-1); %full-employment real wage; negative c_T filtered out below
    w_try = max(par.gama*w_grid,wfull_try); %real wage given downward constraint
    clear('wfull_try') %keep memory use down
    % force onto grid
    w_try(w_try<grid.w_grid_vec(1)) = grid.w_grid_vec(1);
    w_try(w_try>grid.w_grid_vec(end)) = grid.w_grid_vec(end);

    % nearest wage-grid index; assumes w_grid_vec is equally spaced in logs
    % with step grid.w_grid_step_size
    w_try_index = round((log(w_try)-log(grid.w_grid_vec(1)))/grid.w_grid_step_size)+1;
    w_try = grid.w_grid_vec(w_try_index);

    h_try=NaN(size(cT_try));

    % Hours solve the FOC in logs. Instead of one fsolve call per
    % (state, candidate), solve on a coarse (const, g_rule) grid and
    % interpolate with a spline.
    n_grid_h=400;
    n_grid_g_rule=400;
    g_rule_small_grid=linspace(min(g_rule_grid(:)),max(g_rule_grid(:)),n_grid_g_rule)';

    selector_matrix=find(cT_try>0);
    const=log(w_try(selector_matrix)*par.omega/(1-par.omega)/par.alfa./(cT_try(selector_matrix).^(1/par.xi)));
    const_grid=linspace(min(const),max(const),n_grid_h)';
    [const_mat,G_mat]=ndgrid(const_grid,g_rule_small_grid);

    % Extra parameters are bound through an anonymous function rather than
    % the legacy trailing-argument form of fsolve.
    [h_const_grid,fval,exitflag]=fsolve(@(h) h_foc_log_const_sparse(h,par.alfa,G_mat(:),par.xi,const_mat(:)),...
        0.1*ones(n_grid_h*n_grid_g_rule,1),options);
    if max(abs(imag(h_const_grid)))>1e-5
        error('Complex values encountered')
    else
        h_const_grid=real(h_const_grid);
        temp=reshape(h_const_grid,n_grid_h,n_grid_g_rule);
%         mesh(const_mat,G_mat,temp);
        h_interp=griddedInterpolant(const_mat,G_mat,temp,'spline'); % create interpolation object
        h_full=h_interp(const,g_rule_grid(selector_matrix));
    end
    % check that the interpolated hours satisfy the log FOC at every point used
    resid=log(w_try(selector_matrix)*par.omega/(1-par.omega)/par.alfa) -...
        (-1/par.xi)*log((h_full.^par.alfa-g_rule_grid(selector_matrix))./cT_try(selector_matrix))-(par.alfa-1)*log(h_full);
    %plot(resid)
    if max(abs(resid))>1e-5
        fprintf('Maximum absolute FOC residual: %g\n',max(abs(resid)));
        error('FOC does not hold')
    else
        h_try(selector_matrix)=h_full;
    end


    clear('const','const_grid','const_mat','G_mat','g_rule_small_grid','temp','resid','fval','exitflag','h_const_grid','h_interp','h_full','w_try','selector_matrix')

    % hours; entries with cT_try<=0 are still NaN here and min ignores NaN,
    % so they become par.hbar
    h_try = min(h_try,par.hbar);

    c_try = (par.omega*abs(cT_try).^(1-1/par.xi) + (1-par.omega)*(h_try.^par.alfa-g_rule_grid).^(1-1/par.xi)).^(1/(1-1/par.xi)); %composite consumption

    %Marginal utility of consumption of tradables (lambda_t)
    lambda_try = par.omega*abs(cT_try).^(-1/par.xi).*c_try.^(1/par.xi-par.sigg);

    lambda_try(cT_try<0) = -inf; % infeasible candidates are never selected

    clear('c_try','cT_try')
    %done with constructing lambda_try

    % for any choice of d_t+1 in narrow window, d_s, and resulting choice for
    % w_t, w_qq(d_s), get selection index for resulting
    % lambda_t+1(y,r,d_s,w_qq(d_s)), where at time t+1 d_s and w_qq(d_s) are now states
    % (the second subscript spans the combined (d,w) dimension of lambda)
    lambdap_index_try = sub2ind(size(lambda),... %array size n_y*n_d*n_w
        repmat((1:grid.n_y)',[grid.n_d*grid.n_w narrow_window_dp_size]),grid.n_d*(w_try_index-1)+dp_try_index);
%     save('w_try_index','w_try_index','-v7.3');
%     clear('w_try_index');
    lambda = lambda(:);

    dp_narrow_window_index_old = zeros(n,1);
    %dp stands for debt chosen today and due next period
    %fp stands for fewer points (we look only at a window around the current debt level, as opposed to at the entire debt grid)

    dist = 1; iter=0;

    while dist> 1e-8 && iter<max_iter %main iteration over Euler equation, taken narrow window as given
        tic
        iter=iter+1
        Euler_try=MC.P_trans*reshape(lambda,grid.n_y,grid.n_d*grid.n_w); %E_t(lambda_t+1) for all n_y*n_d*n_w

        Euler_try = Euler_try(lambdap_index_try); %E_t(lambda_t+1(y,r,d_s,w_qq(d_s)))
        Euler_try = lambda_try./(1+r_grid)-par.betta*Euler_try;% Euler error, i.e. Lagrange multiplier mu

        [~, dp_narrow_window_index] = min(abs(Euler_try),[],2); %find point for d_t+1 that minimizes Euler error

        dp_narrow_window_index(Euler_try(:,end) > 0 & dp_try_index_last_column_equal_nd) =...
            narrow_window_dp_size; % BC binding: mu is positive and debt choice is maximum

        lambda_new = lambda_try(sub2ind(size(lambda_try), (1:n)', dp_narrow_window_index)); %update lambda_t
        dist = max(abs(lambda_new-lambda))
        plot(lambda_new-lambda)
        drawnow
        share_of_updated_d_choices=mean(dp_narrow_window_index_old~=dp_narrow_window_index)

        lambda = lambda_new; %update lambda
        dp_narrow_window_index_old = dp_narrow_window_index; %update index
        toc
    end %while dist>1e-8
    if iter==max_iter
        error('Inner loop did not converge')
    end
    clear('Euler_try','dp_narrow_window_index_old','lambda_try','lambdap_index_try','lambda_new')
    lambda = reshape(lambda,grid.n_y,grid.n_d,grid.n_w);

    %Update debt choice matrix for d_t+1 based on previous iteration
    dp = reshape(grid.d_grid(dp_try_index(sub2ind(size(dp_try_index),(1:n)',dp_narrow_window_index))),grid.n_y,grid.n_d,grid.n_w);

    %Check whether the narrow window is too small.
    %Step 1: when the upper end of a window is chosen, check whether the full
    %debt grid contains higher points that could have been chosen.

    a_up=find(dp_narrow_window_index==narrow_window_dp_size & dp(:)~=grid.d_grid(end));%find cases when the upper bound of narrow window is chosen and not equal to debt limit
    %if a_up is empty (that is you never hit the upper bound or you hit it when dp=dupper)
    if  isempty(a_up)
        not_hit_high = 1 %1==> no problem; 0==>problem
    else %update dp_try_index
        not_hit_high = 0
        dp_try_index(a_up,:) = (dp_try_index(a_up,:)+shift_n_dp_try).*... %move window up by shift_n_dp_try
            repmat(dp_try_index(a_up,end)<=grid.n_d-shift_n_dp_try,[1 narrow_window_dp_size]) ... %if not yet at upper end
            + repmat(grid.n_d-narrow_window_dp_size+1:grid.n_d,[numel(a_up) 1]).*... %choose highest possible window
            (1-repmat(dp_try_index(a_up,end)<=grid.n_d-shift_n_dp_try,[1 narrow_window_dp_size])); %if shift by shift_n_dp_try would not be possible
    end %if isempty(a_up)

    %Step 2: same check at the lower end of the window (is the chosen dp the lowest debt-grid point?)
    a_lo=find(dp_narrow_window_index==1 & dp(:)~=grid.d_grid(1));%find states at the lower window bound whose dp is not the lowest grid point
    if isempty(a_lo)
        not_hit_low = 1 %1==> no problem; 0==>problem
    else %update dp_try_index
        not_hit_low = 0
        dp_try_index(a_lo,:) = (dp_try_index(a_lo,:)-shift_n_dp_try).*...%move window down by shift_n_dp_try
            repmat(dp_try_index(a_lo,1)>shift_n_dp_try,[1 narrow_window_dp_size]) ... %if not yet at lower end
            + repmat(1:narrow_window_dp_size,[numel(a_lo) 1]).* ... %choose lowest possible window
            (1-repmat(dp_try_index(a_lo,1)>shift_n_dp_try,[1 narrow_window_dp_size])); %if shift by shift_n_dp_try would not be possible
    end %if isempty(a_lo)
    overall_iter_number=overall_iter_number+1
end %while ~(not_hit_high && not_hit_low)
% Defensive check; the inner loop already raises this error itself.
if iter==max_iter
    error('Inner loop did not converge')
end

% Release work arrays that are not needed for the policy functions.
clear('dp_try_index_last_column_equal_nd','yT_grid','share_of_updated_d_choices','r_grid','options','w_grid','g_rule_grid')

% For each state, the linear position (within the n x narrow_window_dp_size
% candidate arrays) of the selected d_t+1 candidate; reused below.
chosen_lin_idx = sub2ind(size(dp_try_index),(1:n)',dp_narrow_window_index);
state_dims = [grid.n_y grid.n_d grid.n_w];

dp_index = reshape(dp_try_index(chosen_lin_idx),state_dims); % debt-grid index of the chosen d_t+1

disp('Compute policy functions, dp and w')

% load('w_try_index','w_try_index');
w_index = w_try_index(chosen_lin_idx); % wage-grid index implied by the chosen d_t+1 (n x 1)
clear('w_try_index');
w = grid.w_grid_vec(reshape(w_index,state_dims));

% Consumption of tradables, broadcast over the (y,d,w) state array
cT = grid.y_level + dp./(1+grid.r_level) - reshape(grid.d_grid,1,grid.n_d);
% Composite consumption from inverting lambda = omega*cT^(-1/xi)*c^(1/xi-sigg)
c = ((lambda.*cT.^(1/par.xi))/par.omega).^(1/(1/par.xi-par.sigg));

h = reshape(h_try(chosen_lin_idx),state_dims); % hours at the chosen candidate

clear('chosen_lin_idx','state_dims')

% resid=log(w*par.omega/(1-par.omega)/par.alfa) -...
%         (-1/par.xi)*log((h.^par.alfa-repmat(g_grid_level,[1 grid.n_d grid.n_w]))./cT)-(par.alfa-1)*log(h);

disp('Compute period utility, u')

% CRRA period utility. At sigg==1 the general formula evaluates to 0/0 (NaN),
% which made the value iteration below stop after one step with v = NaN;
% use the limiting case log(c) instead.
if par.sigg==1
    u = log(c);
else
    u = (c.^(1-par.sigg)-1)/ (1-par.sigg); %period utility
end
u = u(:);

%Compute the Value Function under the converged policy by iterating
%v = u + betta*E_t[v(y_t+1, d_t+1, w_t)] until the sup-norm change is below 1e-8
v = zeros(n,1);
distv = 1;
% Linear index into the n_y x (n_d*n_w) array of conditional expectations
% for each state's chosen (d_t+1, w_t); dp_index(:) holds the chosen debt-grid index.
I = sub2ind([grid.n_y grid.n_d grid.n_w],repmat((1:grid.n_y)',[grid.n_d*grid.n_w 1]),grid.n_d*(w_index-1)+dp_index(:));
while distv>1e-8
    v_temp=MC.P_trans * reshape(v,grid.n_y,grid.n_d*grid.n_w);
    v1 = u + par.betta *v_temp(I);
    distv = max(abs(v-v1))
    v = v1;
end %while distv

v = reshape(v,grid.n_y,grid.n_d,grid.n_w);

clear('v_temp','I','v1','distv','u','dp_try_index','a_lo','a_up','h_try','dp_narrow_window_index')
% delete('w_try_index.mat');

save([filename '.mat'],'-v7.3')% saves every variable still in the workspace (e.g. dp, dp_index, w, h, c, cT, lambda, v, grid, par, MC, not_hit_high, not_hit_low)